Add possibility to switch between APEX and AMP in Trainer #9137
Conversation
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
self.assertLess(train_output.global_step, 20 * 64 / 16)
with tempfile.TemporaryDirectory() as tmp_dir:
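For context, the test above asserts that training halts before the full step budget once the metric stops improving. A minimal sketch of that early-stopping behavior (the helper below is illustrative, not the Trainer's actual implementation; `patience` and `threshold` mirror the `EarlyStoppingCallback(1, 0.0001)` arguments):

```python
def run_with_early_stopping(metrics, patience=1, threshold=0.0001):
    """Return the number of evaluations taken before stopping.

    Stops once the metric has failed to improve by more than `threshold`
    for `patience` consecutive evaluations (lower is assumed better).
    """
    best = None
    bad_rounds = 0
    steps = 0
    for m in metrics:
        steps += 1
        if best is None or m < best - threshold:
            best = m
            bad_rounds = 0
        else:
            bad_rounds += 1
            if bad_rounds >= patience:
                break
    return steps

# A loss that plateaus triggers the stop before the list is exhausted:
run_with_early_stopping([1.0, 0.5, 0.4, 0.4, 0.4])  # stops at evaluation 4
```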
This test was saving things in a regression folder and adding lots of unwanted files. Moving it to a temp folder.
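The pattern the test is being moved to can be sketched as follows: write test artifacts inside a `TemporaryDirectory` so nothing is left behind once the test finishes (filenames here are illustrative).

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    # Anything written under tmp_dir lives only for the duration of the block.
    path = os.path.join(tmp_dir, "checkpoint.txt")
    with open(path, "w") as f:
        f.write("state")
    assert os.path.exists(path)

# On exit, the directory and its contents are removed automatically.
assert not os.path.exists(path)
```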
try:
    trainer.train()
except AssertionError:
with tempfile.TemporaryDirectory() as tmp_dir:
Same here.
Absolutely fantastic! Thank you, @sgugger!
I added a few small suggestions in the comments.
This PR also removes
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Great, LGTM. The new auto is clean.
There was one nit that you agreed with but didn't integrate, but I'm fine if it remains as merged; it's just a potential for divergence down the road.
Oh, which one did I miss?
Argh, one sec - I see what happened - that wasn't what I meant - sorry for not being clear. I tried to suggest not repeating the options in the help comment - please let's have
Fixed directly on master in this commit.
That's perfect. Thank you, @sgugger!
What does this PR do?
When PyTorch >= 1.6 is installed, the Trainer currently always uses native AMP. This PR adds the option to switch between AMP and APEX, which can be useful:
It also slightly simplifies the Trainer internals.
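The selection logic this PR enables can be sketched as follows. This is a hedged, self-contained sketch, not the Trainer's actual code: the backend names "auto", "amp", and "apex" follow the discussion above, and the helper `select_fp16_backend` is a hypothetical name introduced here for illustration ("auto" prefers native AMP when available, falling back to APEX).

```python
def select_fp16_backend(requested, native_amp_available, apex_available):
    """Resolve a requested mixed-precision backend to a concrete one.

    `requested` is one of "auto", "amp", or "apex". Native AMP
    (torch.cuda.amp) ships with PyTorch >= 1.6; APEX is a separate install.
    """
    if requested == "auto":
        if native_amp_available:
            return "amp"
        if apex_available:
            return "apex"
        raise ValueError("No mixed-precision backend available")
    if requested == "amp" and not native_amp_available:
        raise ValueError("Native AMP requires PyTorch >= 1.6")
    if requested == "apex" and not apex_available:
        raise ValueError("APEX is not installed")
    return requested

select_fp16_backend("auto", True, True)  # "amp": native AMP wins when present
```

An explicit `"apex"` request bypasses the auto preference, which is the switching behavior the PR adds on top of the previous always-native-AMP default.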